import torch.nn as nn


# TODO: add other loss functions
class ViTLoss(nn.Module):
    """Classification loss wrapper built on ``nn.CrossEntropyLoss``.

    Calling the module with ``(logits, targets)`` returns the cross-entropy
    loss. Calling it with no arguments returns the underlying
    ``nn.CrossEntropyLoss`` instance, which preserves the original
    (argument-less) calling convention for existing callers.
    """

    def __init__(self, **ce_kwargs):
        """Create the wrapped criterion.

        Args:
            **ce_kwargs: Optional keyword arguments forwarded unchanged to
                ``nn.CrossEntropyLoss`` (e.g. ``weight``, ``ignore_index``,
                ``reduction``, ``label_smoothing``). With no arguments the
                criterion uses PyTorch's defaults, as before.
        """
        super().__init__()

        self.cross_loss = nn.CrossEntropyLoss(**ce_kwargs)

    def forward(self, logits=None, targets=None):
        """Compute the cross-entropy loss between ``logits`` and ``targets``.

        Args:
            logits: Unnormalized class scores, passed as the ``input`` of
                ``nn.CrossEntropyLoss``.
            targets: Class indices (or class probabilities), passed as the
                ``target`` of ``nn.CrossEntropyLoss``.

        Returns:
            The loss tensor produced by ``nn.CrossEntropyLoss``. If both
            arguments are omitted, the wrapped ``nn.CrossEntropyLoss`` module
            itself is returned, matching the original behaviour.

        Raises:
            ValueError: If exactly one of ``logits`` / ``targets`` is given.
        """
        # Legacy path: the original forward() took no inputs and handed back
        # the criterion object; keep that working for existing callers.
        if logits is None and targets is None:
            return self.cross_loss
        if logits is None or targets is None:
            raise ValueError(
                "ViTLoss.forward requires both logits and targets to compute the loss"
            )
        return self.cross_loss(logits, targets)
